import torch
import torch.nn as nn
from Models.Backbone import BackboneNet, Convention


class YOLOFeature(BackboneNet):
    """Classifier built on top of ``BackboneNet``.

    The input is passed through the ``Conv_Feature`` and ``Conv_Semanteme``
    stages (not defined in this class; presumably provided by
    ``BackboneNet`` -- confirm), globally average-pooled, and projected to
    ``classes_num`` outputs by a linear layer.
    """

    def __init__(self, classes_num: int = 20) -> None:
        """Create the pooling and linear layers of the classifier.

        The original paper uses 20 classes, so 20 is the default here too.

        :param classes_num: number of outputs of the final linear layer.
        """
        super().__init__()
        self.classes_num = classes_num

        # Output size 1 collapses every channel's spatial map to one value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # NOTE(review): 1024 has to match the channel count produced by
        # Conv_Semanteme -- confirm against BackboneNet.
        self.linear = nn.Linear(1024, self.classes_num)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return the linear layer's raw outputs.

        :param x: batched input tensor accepted by ``Conv_Feature``.
        :return: tensor whose last dimension has size ``classes_num``.
        """
        x = self.Conv_Feature(x)
        x = self.Conv_Semanteme(x)
        x = self.avg_pool(x)
        # The pooled map is (N, C, 1, 1); with both spatial dimensions equal
        # to 1, flattening from dim 1 yields (N, C) directly, so no permute
        # to channels-last is needed first.
        x = torch.flatten(x, start_dim=1)
        return self.linear(x)

    def initialize_weights(self) -> None:
        """Initialise the weights of every sub-module in place.

        * ``nn.Conv2d``: Kaiming-normal weights (bias left untouched).
        * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
        * ``nn.Linear``: Kaiming-normal weights, bias set to 0.
        * ``Convention``: delegated to its own ``weight_init()``.

        Parameters that are ``None`` (e.g. ``affine=False`` or
        ``bias=False``) are skipped.
        """
        # The nn.init helpers run under no_grad themselves, so the parameters
        # can be passed directly instead of going through ``.data``.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, Convention):
                # NOTE(review): modules() yields a parent before its
                # children, so if Convention wraps Conv2d/BatchNorm2d layers
                # they are initialised again by the branches above after
                # weight_init() has run -- confirm this is intended.
                m.weight_init()



